@newling newling commented Oct 2, 2025

This PR adds a canonicalizer to vector.step that folds vector.step iff the result of the fold is a splat value. An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants.

I do wonder if vector.step might be better represented as some sort of attribute in the arith dialect, like %step = arith.constant iota<32> : vector<32xindex>.
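
The fold condition described above can be sketched outside of MLIR as a small decision function. This is only an illustration, assuming a hypothetical `Pred` enum in place of `arith::CmpIPredicate` and a hypothetical helper name `splatOfStepCompare`; it follows the "constant is large enough" cases handled in this PR and returns `std::nullopt` whenever the result is not known to be uniform.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the arith::CmpIPredicate values the pattern handles.
enum class Pred { ult, ule, ugt, uge, eq, ne };

// For step = [0, 1, ..., n-1] compared lane-wise against splat(c) as
// `step <pred> c`, return the value shared by all lanes, or std::nullopt when
// the comparison is not known to be uniform. Only the "constant is large
// enough" cases are recognized, so std::nullopt is the conservative answer.
std::optional<bool> splatOfStepCompare(Pred pred, uint64_t n, uint64_t c) {
  assert(n > 0 && "a step vector has at least one element");
  const uint64_t last = n - 1; // largest value the step produces

  switch (pred) {
  case Pred::ult: // every lane < c
  case Pred::uge: // no lane >= c
    if (last < c)
      return pred == Pred::ult;
    break;
  case Pred::ule: // every lane <= c
  case Pred::ugt: // no lane > c
    if (last <= c)
      return pred == Pred::ule;
    break;
  case Pred::eq: // no lane == c
  case Pred::ne: // every lane != c
    if (last < c)
      return pred == Pred::ne;
    break;
  }
  return std::nullopt;
}
```

Comparing against the last step value `n - 1` rather than forming `c + 1` is a choice made for this sketch; it sidesteps overflow when `c` is the maximum representable value.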

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/161615.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+95)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-step.mlir (+379)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..dbb5d0f659159 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [
   }];
   let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
   let assemblyFormat = "attr-dict `:` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_YieldOp : Vector_Op<"yield", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb4686997c1b9..306be186308b0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7524,6 +7524,101 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+///
+/// as,
+///
+/// %out = arith.constant dense<false> : vector<3xi1>.
+///
+/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
+/// false at ALL indices we fold. If the constant was 1, then
+/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+    for (auto &use : stepOp.getResult().getUses()) {
+      if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
+        const unsigned stepOperandNumber = use.getOperandNumber();
+
+        // arith.cmpi canonicalizer makes constants final operands.
+        if (stepOperandNumber != 0)
+          continue;
+
+        // Check that operand 1 is a constant.
+        unsigned constOperandNumber = 1;
+        Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+        auto maybeConstValue = getConstantIntValue(otherOperand);
+        if (!maybeConstValue.has_value())
+          continue;
+
+        int64_t constValue = maybeConstValue.value();
+        arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+        auto maybeSplat = [&]() -> std::optional<bool> {
+          // Handle ult (unsigned less than) and uge (unsigned greater equal).
+          if ((pred == arith::CmpIPredicate::ult ||
+               pred == arith::CmpIPredicate::uge) &&
+              stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ult;
+
+          // Handle ule and ugt.
+          if ((pred == arith::CmpIPredicate::ule ||
+               pred == arith::CmpIPredicate::ugt) &&
+              stepSize <= constValue + 1)
+            return pred == arith::CmpIPredicate::ule;
+
+          // Handle eq and ne.
+          if ((pred == arith::CmpIPredicate::eq ||
+               pred == arith::CmpIPredicate::ne) &&
+              stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ne;
+
+          return std::optional<bool>();
+        }();
+
+        if (!maybeSplat.has_value())
+          continue;
+
+        rewriter.setInsertionPointAfter(cmpiOp);
+
+        auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+        if (!type)
+          continue;
+
+        DenseElementsAttr boolAttr =
+            DenseElementsAttr::get(type, maybeSplat.value());
+        Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+                                                      type, boolAttr);
+
+        rewriter.replaceOp(cmpiOp, splat);
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<StepCompareFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
new file mode 100644
index 0000000000000..effeb3d9c093a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -0,0 +1,379 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+///===----------------------------------------------===//
+///  Tests of `StepCompareFolder`
+///===----------------------------------------------===//
+
+
+///===------------------------------------===//
+///  Tests of `ugt` (unsigned greater than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ugt_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 3 > [0, 1, 2] => true
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ugt_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 > [0, 1, 2] => not constant
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 1 > [0, 1, 2] => not constant
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 3 => false
+  %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 2 => false
+  %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 1 => not constant
+  %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `uge` (unsigned greater than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_uge_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 3 >= [0, 1, 2] => true
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 >= [0, 1, 2] => true
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 1 >= [0, 1, 2] => not constant
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 3 => false
+  %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_uge_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 2 => not constant
+  %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 1 => not constant
+  %1 = arith.cmpi uge, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+
+
+///===------------------------------------===//
+///  Tests of `ult` (unsigned less than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ult_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ult_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `ule` (unsigned less than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ule_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ule_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `eq` (equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_eq_constant_3
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_eq_constant_3() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_eq_constant_2
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_eq_constant_2() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `ne` (not equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ne_constant_3
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ne_constant_3() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ne_constant_2
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ne_constant_2() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+

@kuhar kuhar left a comment

An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants.

FYI, we have some max vector sizes defined for insert/extract folders for this exact reason. I think it would work here too.
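
A size cap of the kind mentioned here could gate a full constant fold of the step. The following is a minimal sketch, not the existing insert/extract folder logic: the limit `kMaxElementsToFold` and the helper `materializeStep` are hypothetical names, and the value 256 is an arbitrary placeholder since the thread does not state the actual limit.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical cap; the real limit used by other folders is not given here.
constexpr int64_t kMaxElementsToFold = 256;

// Materialize [0, 1, ..., numElements-1] only when the vector is small enough
// that a dense constant stays manageable; otherwise decline to fold.
std::optional<std::vector<int64_t>>
materializeStep(int64_t numElements,
                int64_t maxElements = kMaxElementsToFold) {
  assert(maxElements >= 0 && "the cap must be non-negative");
  if (numElements <= 0 || numElements > maxElements)
    return std::nullopt;
  std::vector<int64_t> values(numElements);
  std::iota(values.begin(), values.end(), int64_t{0});
  return values;
}
```

With this shape, `materializeStep(3)` yields the three values 0, 1, 2, while a request for 1000 elements under the default cap is declined and the compact `vector.step` form would be kept.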

PatternRewriter &rewriter) const override {
const int64_t stepSize = stepOp.getResult().getType().getNumElements();

for (auto &use : stepOp.getResult().getUses()) {

nit: spell out the type since it's not immediately obvious based on the RHS

// Check that operand 1 is a constant.
unsigned constOperandNumber = 1;
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
auto maybeConstValue = getConstantIntValue(otherOperand);

nit: spell out the type

newling added a commit to iree-org/iree that referenced this pull request Oct 6, 2025
This is a 'direct' approach alternative to the upstream PR
llvm/llvm-project#161615.

It directly folds the case where the linalg.index is known to always be
less than the unpadded dimension size. This happens when
`partial_reduction` perfectly divides the reduction dimension.

The upstream approach (see link above) folds at a later stage, where,
after vectorization, there is a vector.step compared to a constant.

---------

Signed-off-by: James Newling <[email protected]>
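
The divisibility argument in this commit message can be checked with a short sketch. It assumes the reduction dimension is covered by fixed-size tiles with the last tile padded up to a full tile; `allTileIndicesInBounds` and `tileDividesDimension` are hypothetical names, and "tile size" stands in for the `partial_reduction` size.

```cpp
#include <cassert>
#include <cstdint>

// Brute-force check: does every index visited by tiles of `tileSize` (with the
// last tile padded up to a full tile) fall below `dimSize`?
bool allTileIndicesInBounds(int64_t dimSize, int64_t tileSize) {
  assert(dimSize > 0 && tileSize > 0);
  const int64_t numTiles = (dimSize + tileSize - 1) / tileSize; // ceil division
  for (int64_t tile = 0; tile < numTiles; ++tile)
    for (int64_t i = 0; i < tileSize; ++i)
      if (tile * tileSize + i >= dimSize)
        return false;
  return true;
}

// Closed form of the same property: no padding is needed exactly when the
// tile size divides the dimension size.
bool tileDividesDimension(int64_t dimSize, int64_t tileSize) {
  assert(dimSize > 0 && tileSize > 0);
  return dimSize % tileSize == 0;
}
```

When the second function returns true, an "index < unpadded size" comparison is true in every lane, which is the case the direct approach folds.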

@kuhar kuhar left a comment

LGTM % remaining nits

Comment on lines 7624 to 7626
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively

Suggested change
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively
/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold, conservatively

@banach-space banach-space left a comment

LGTM, thanks!

Personally I would reduce the number of tests by checking fewer _lhs vs _rhs variants. There's also one other [nit] inline 😅

Signed-off-by: James Newling <[email protected]>

newling commented Oct 14, 2025

Personally I would reduce the number of tests by checking fewer _lhs vs _rhs variants.

I've reduced the number of tests by about one third, by testing only 2 cases for each variant instead of 3.

@newling newling merged commit bea77ed into llvm:main Oct 15, 2025
10 checks passed
@dcaballe

Nice contribution!

Couldn't we make vector.step a ConstantLike op? That would make the arith.constant canonicalizations also trigger for vector.step.

@Groverkss

Nice contribution!

Couldn't we make vector.step a ConstantLike op? That would make the arith.constant canonicalizations also trigger for vector.step.

Interesting, do you mean making it something like arith.constant iota<end>? Or is there a ConstantLike interface (maybe we should add one if not)? I'll have a deeper look into the codebase too, but could use more info on what you had in mind.

@dcaballe

Or is there a ConstantLike interface (maybe we should add that if not?)

Yes, there is a ConstantLike trait that different constant ops are using. I'm not sure if we'd need to turn this into an interface but this is a path we can explore. That would indeed require an attribute for vector.step to encode the constant information. The key point, IMO, is that we shouldn't assume that the stride of the step is always 1 as this is something we want to parametrize in the future.
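
To illustrate why the stride matters for a fold like the one in this PR, here is a sketch of the "always less than" check for a strided step `[0, s, 2s, ..., (n-1)s]`. A stride parameter does not exist on `vector.step` in this PR; `lastStridedValue` and `stridedStepAlwaysUlt` are hypothetical names used only for this example.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Largest element of [0, stride, ..., (n-1)*stride], or std::nullopt if the
// product does not fit in 64 bits.
std::optional<uint64_t> lastStridedValue(uint64_t n, uint64_t stride) {
  assert(n > 0 && "a step vector has at least one element");
  const uint64_t count = n - 1;
  if (stride != 0 && count > std::numeric_limits<uint64_t>::max() / stride)
    return std::nullopt;
  return count * stride;
}

// True when every lane of the strided step is known to be < c. Overflow is
// treated conservatively as "not known".
bool stridedStepAlwaysUlt(uint64_t n, uint64_t stride, uint64_t c) {
  const std::optional<uint64_t> last = lastStridedValue(n, stride);
  return last.has_value() && *last < c;
}
```

With stride 1 this reduces to the `n <= c` condition used for `ult` in the pattern; with stride 2 and three elements the last value is 4, so a constant of 4 no longer gives a uniform result while 5 does.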
